import os

import gurobipy as gp
import numpy as np
from gurobipy import GRB
from joblib import Parallel, delayed

# from xbcausalforest import XBCF
from DataGeneration.DataGeneration import GenerateData
from DataGeneration.PosteriorGeneration import ComputePosterior, ComputePosteriorXBCF

def _evaluation_grid(n_points=100):
    """Return an (n_points**2, 2) array of points on a regular grid over [-1, 1]^2.

    Row k holds (axis[k % n_points], axis[k // n_points]), i.e. the first
    coordinate varies fastest, matching ``np.meshgrid``'s default 'xy' order.
    """
    axis = np.linspace(-1, 1, n_points)
    grid_first, grid_second = np.meshgrid(axis, axis)
    return np.column_stack((grid_first.reshape(-1), grid_second.reshape(-1)))


# Fixed evaluation grid on which learned policies are scored.
X_evaluation = _evaluation_grid()
# X_evaluation = None


def maketable(t1, t2, t3):
    """Stack three equally-shaped arrays along a new leading axis."""
    return np.stack([t1, t2, t3])

def Gub_constraint(Existing_policy, EA, NIP, N, X, epsilon):
    """Solve a risk-constrained policy-improvement problem with Gurobi.

    Chooses binary assignments ``y`` (one per training point) induced by a
    linear threshold on the first two covariates. Choosing ``y_i = 1`` forces
    ``b0*X[i, 0] + b1*X[i, 1] >= 1``, and choosing ``y_i = 0`` forces it to be
    ``<= 1``. The solve maximises
    ``EPI = y @ EA - Existing_policy @ EA`` subject to ``Risk <= epsilon``. Here
    ``Risk = (y - Existing_policy) @ diag(NIP / N) @ (y - Existing_policy)``,
    written out in expanded form below.

    Parameters
    ----------
    Existing_policy : array of length N
        Baseline policy; also used as the MIP start for ``y``. Presumably 0/1
        valued (TODO confirm).
    EA : array of length N
        Per-point objective weights. Presumably posterior expected treatment
        effects from ComputePosterior* (NOTE(review): confirm).
    NIP : array of length N
        Values placed on the diagonal of the risk matrix (scaled by 1/N).
    N : int
        Number of training points / binary decision variables.
    X : array indexed as ``X[i, 0]``, ``X[i, 1]`` for ``i < N``
        Covariates; only the first two columns are read.
    epsilon : float
        Upper bound on ``Risk``.

    Returns
    -------
    gurobipy.Model
        The model after ``optimize()``. Callers read the threshold
        coefficients as ``m.X[0]`` and ``m.X[1]``. That relies on ``b[0]`` and
        ``b[1]`` being the first variables added, so keep them first.

    NOTE(review): ``m.Status`` is not checked. If the solve ends without a
    solution, reading ``m.X`` in the caller will fail.
    """
    m = gp.Model('conditional')
    m.Params.LogToConsole = 0
    # Threshold coefficients, bounded strictly positive. These MUST stay the
    # first variables created (callers read them via m.X[0], m.X[1]).
    b = [m.addMVar(1, lb=0.001, ub=1000, vtype=GRB.CONTINUOUS), m.addMVar(1, lb=0.001, ub=1000, vtype=GRB.CONTINUOUS)]
    y = m.addMVar(N, vtype=GRB.BINARY)
    # Link y to the threshold rule: y_i = 1 => b.x_i - 1 >= 0; y_i = 0 => b.x_i - 1 <= 0.
    # NOTE(review): these are bilinear (binary * continuous). Depending on the
    # Gurobi version, non-convex constraints may need Params.NonConvex = 2 —
    # confirm for the installed release.
    m.addConstrs(y[i] * (b[1] * X[i, 1] + b[0] * X[i, 0] - 1) >= 0 for i in range(N))
    m.addConstrs((1 - y[i]) * (b[1] * X[i, 1] + b[0] * X[i, 0] - 1) <= 0 for i in range(N))
    # Objective: improvement of the new policy over the existing one under weights EA.
    EPI = y @ EA - Existing_policy @ EA
    # Diagonal risk matrix. This is a dense N x N numpy array; only the
    # diagonal is nonzero.
    Meat = np.diag(NIP) / N
    # Expanded (y - e)^T Meat (y - e) with e = Existing_policy (Meat is symmetric).
    Risk = y @ Meat @ y - 2 * Existing_policy @ Meat @ y + Existing_policy @ Meat @ Existing_policy
    m.setObjective(EPI, GRB.MAXIMIZE)
    m.addConstr(Risk <= epsilon, name='risk')
    # Warm-start from the existing policy (gives Risk = 0).
    y.Start = Existing_policy
    m.optimize()
    return m
def Simulate_con(myseed=None, sigma=None):
    """Run one simulation replicate and score the risk-constrained policies.

    Steps:

    1. Generate a dataset with ``GenerateData``, using outcome model ``eyx1``,
       logging rule ``collecting_rule`` and sample size ``N``.
    2. Compute posterior summaries with ``ComputePosteriorXBCF``.
    3. Solve ``Gub_constraint`` once per risk budget in ``epsilon_all``.
    4. Evaluate each resulting linear-threshold policy on ``X_evaluation``.

    This function reads the module-level globals ``eyx1``, ``collecting_rule``,
    ``N``, ``basepolicy``, ``epsilon_all`` and ``X_evaluation``.

    Parameters
    ----------
    myseed : optional
        Forwarded to ``GenerateData`` as ``seed``. NOTE(review): the driver
        loop never passes it, so replicates are not seeded reproducibly.
    sigma : optional
        Noise level forwarded to ``GenerateData``.

    Returns
    -------
    np.ndarray of shape (2, len(epsilon_all))
        Row 0: for each epsilon, the fraction of grid points where the base
        policy agreed with the optimal policy but the new policy does not.
        Row 1: for each epsilon, the mean of ``eyx1`` on the grid under the
        new policy.
    """
    length_scale = None
    sigma_kernel = None
    # Optimal rule: take action 1 wherever its expected outcome is higher.
    Optimal_policy = lambda X, eyx: 1 * (eyx(X, 1) > eyx(X, 0))
    DG = GenerateData(Ydist='normal', eyx=eyx1, sigma=sigma, collecting_rule=collecting_rule, k=2, N=N, seed=myseed)
    X, A, Ey, Y = DG.GenerateData()
    paramlist = {'length_scale': length_scale, 'sigma_kernel': sigma_kernel}
    EPI, NIP = ComputePosteriorXBCF(X, A, Y, paramlist, trt_price=0)

    # The baseline policy on the training points does not depend on epsilon.
    base_train = basepolicy(X)
    m_all = [Gub_constraint(base_train, EPI, NIP, N, X, epsilon) for epsilon in epsilon_all]
    # Threshold policy from a solved model. m.X[0], m.X[1] are the b
    # coefficients because Gub_constraint creates them first.
    newpolicy_crust = lambda X, m: 1 * ((m.X[0] * X[:, 0] + m.X[1] * X[:, 1] - 1) >= 0)

    # Grid quantities that are identical for every epsilon: compute them once.
    opt_eval = Optimal_policy(X_evaluation, eyx1)
    base_was_optimal = opt_eval == basepolicy(X_evaluation)

    NIP_all = []
    AUI_all = []
    for m in m_all:
        new_eval = newpolicy_crust(X_evaluation, m)
        NIP_all.append(np.logical_and(base_was_optimal, opt_eval != new_eval).mean())
        AUI_all.append(eyx1(X_evaluation, new_eval).mean())
    return np.array([NIP_all, AUI_all])



# Risk budgets (upper bounds on the 'risk' constraint in Gub_constraint)
# swept in every replicate.
epsilon_all = np.array([1e-3, 2e-3, 3.3e-3, 5e-3, 6.7e-3, 8e-3, 1e-2, 2e-2, 3.3e-2, 5e-2, 6.7e-2, 8e-2,
                        1e-1, 0.125, 0.15, 0.175, 2e-1, 3.3e-1, 6.7e-1, 8e-1, 1])
# Mean outcome: X0 + X1, plus (when A = 1) an effect
# 4*|X0|*|X1|*(1{X0 > 0 and X1 > 0} - 0.5). The effect is positive only in
# the first quadrant.
eyx1 = lambda X, A: (X[:, 0] + X[:, 1]
                     + 4 * A * np.abs(X[:, 0]) * np.abs(X[:, 1]) * ((X[:, 1] > 0) * (X[:, 0] > 0) - 0.5))
# Logging policy: draws action A ~ Bernoulli(0.5) independently for each row.
collecting_rule = lambda X: np.random.binomial(1, 0.5, size=X.shape[0])
# Baseline policy: never take action 1.
basepolicy = lambda X: np.zeros(shape=(X.shape[0]))

mysigmalist = [1, 2, 3]
myNlist = [50, 100, 200, 500]

if __name__ == "__main__":
    output_dir = os.path.join('Data', 'Continuous')
    # np.save does not create missing directories.
    os.makedirs(output_dir, exist_ok=True)
    for sigma in mysigmalist:
        # N is read as a global by Simulate_con.
        for N in myNlist:
            temp3 = Parallel(n_jobs=12)(delayed(Simulate_con)(sigma=sigma) for _ in range(5))
            result3 = np.array(temp3)
            # Filename kept exactly as before (including the '_' before '.npy')
            # so existing loaders still find the files.
            mydataname = "_".join(["XBCF", str(N), str(sigma), '.npy'])
            np.save(os.path.join(output_dir, mydataname), result3)
